# Copyright (c) OpenMMLab. All rights reserved.
# Copyright 2023 The HuggingFace Team. All rights reserved.
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.logging import MMLogger
from mmengine.runner import set_random_seed
from PIL import Image
from tqdm.auto import tqdm

from mmagic.registry import MODELS
from mmagic.utils.typing import SampleList
from .stable_diffusion import StableDiffusion

# Logger bound to the current MMLogger instance.
logger = MMLogger.get_current_instance()

# Type alias: a model given either as a dict (presumably a config to be
# built through a registry — TODO confirm) or as an instantiated nn.Module.
ModelType = Union[Dict, nn.Module]


@MODELS.register_module('sd-inpaint')
@MODELS.register_module()
class StableDiffusionInpaint(StableDiffusion):
    """Stable Diffusion model for text-guided image inpainting.

    Reuses the components and helpers of :class:`StableDiffusion` and adds
    an :meth:`infer` that conditions a 9-channel denoising UNet on a mask
    and on the latents of the masked image. Only inference is supported;
    the train/val/test steps raise ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs):
        """Initializes the current class using the same parameters as its
        parent, StableDiffusion.

        This constructor is primarily a pass-through to the parent class's
        constructor. All arguments and keyword arguments provided are directly
        passed to the parent class, StableDiffusion.
        """
        super().__init__(*args, **kwargs)

    @torch.no_grad()
    def infer(self,
              prompt: Union[str, List[str]],
              image: Union[torch.FloatTensor, Image.Image] = None,
              mask_image: Union[torch.FloatTensor, Image.Image] = None,
              height: Optional[int] = None,
              width: Optional[int] = None,
              num_inference_steps: int = 50,
              guidance_scale: float = 7.5,
              negative_prompt: Optional[Union[str, List[str]]] = None,
              num_images_per_prompt: Optional[int] = 1,
              eta: float = 0.0,
              generator: Optional[torch.Generator] = None,
              latents: Optional[torch.FloatTensor] = None,
              show_progress=True,
              seed=1,
              return_type='image'):
        """Run text-guided inpainting and return the generated images.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            image (`torch.FloatTensor`, `Image.Image`, `np.ndarray` or a
                list of PIL images / arrays):
                The image to inpaint. Required despite the ``None`` default.
                See ``prepare_mask_and_masked_image`` for accepted formats.
            mask_image (same types as ``image``):
                The mask marking the regions to inpaint (values >= 0.5 are
                inpainted). Required despite the ``None`` default.
            height (`int`, *optional*,
                defaults to self.unet_sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*,
                defaults to self.unet_sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps.
                More denoising steps usually lead to a higher
                quality image at the expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in
                [Classifier-Free Diffusion Guidance]
                (https://arxiv.org/abs/2207.12598). Classifier-free
                guidance is applied only when ``guidance_scale > 1``.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation.
                Passed to ``_encode_prompt``; expected to be ignored when
                guidance is disabled (``guidance_scale <= 1``).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper:
                https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator] to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents,
                sampled from a Gaussian distribution,
                to be used as inputs for image generation.
                Can be used to tweak the same generation
                with different prompts.
                If not provided, a latents tensor will be
                generated by sampling using the supplied random `generator`.
            show_progress (bool): Whether to show a tqdm progress bar over
                the denoising steps. Defaults to True.
            seed (int): Seed passed to ``set_random_seed`` before
                generation. Defaults to 1.
            return_type (str): The return type of the inference results.
                Supported types are 'image', 'numpy', 'tensor'. If 'image'
                is passed, a list of PIL images will be returned. If 'numpy'
                is passed, a numpy array with shape [N, C, H, W] will be
                returned, and the value range will be same as decoder's
                output range. If 'tensor' is passed, the decoder's output
                will be returned. Defaults to 'image'.

        Returns:
            dict: ``{'samples': images}`` where ``images`` has the type
            selected by ``return_type``.

        Raises:
            NotImplementedError: If the UNet has 4 input channels; inpainting
                with a non-inpainting (4-channel) UNet is not implemented.
            ValueError: If the UNet does not have 4 or 9 input channels, or
                if the latent, mask and masked-image channels do not add up
                to the UNet's input channels.
        """
        assert return_type in ['image', 'tensor', 'numpy']
        set_random_seed(seed=seed)

        # 0. Default height and width to unet
        height = height or self.unet_sample_size * self.vae_scale_factor
        width = width or self.unet_sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        total_batch_size = batch_size * num_images_per_prompt
        device = self.device

        # The VAE and the UNet may each be wrapped in a module that exposes
        # the real model as `.module`; unwrap them independently.
        vae = getattr(self.vae, 'module', self.vae)
        unet = getattr(self.unet, 'module', self.unet)
        img_dtype = vae.dtype
        latent_dtype = next(self.unet.parameters()).dtype
        num_channels_latents = vae.latent_channels
        num_channels_unet = unet.in_channels

        # Fail fast, before any expensive encoding, on unsupported UNets.
        if num_channels_unet == 4:
            raise NotImplementedError(
                'Inpainting with a 4-channel (non-inpainting) UNet is not '
                'implemented; please use a 9-channel inpainting UNet.')
        if num_channels_unet != 9:
            raise ValueError(
                f'The unet {self.unet.__class__} should have either 4 or 9 '
                f'input channels, not {num_channels_unet}.')

        # here `guidance_scale` is defined analog to the
        # guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf .
        # `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(prompt, device,
                                              num_images_per_prompt,
                                              do_classifier_free_guidance,
                                              negative_prompt)

        # 4. Prepare timesteps
        self.test_scheduler.set_timesteps(num_inference_steps)
        timesteps = self.test_scheduler.timesteps

        # 5. Prepare mask and image
        mask, masked_image = prepare_mask_and_masked_image(
            image, mask_image, height, width)

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            total_batch_size,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare masked image latents
        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            total_batch_size,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

        # 8. Check that sizes of mask, masked image and latents match
        num_channels_mask = mask.shape[1]
        num_channels_masked_image = masked_image_latents.shape[1]
        total_channels = (
            num_channels_latents + num_channels_masked_image +
            num_channels_mask)
        if total_channels != num_channels_unet:
            raise ValueError(
                'Incorrect configuration settings! `pipeline.unet` expects'
                f' {num_channels_unet} input channels but received '
                f'`num_channels_latents`: {num_channels_latents} +'
                f' `num_channels_mask`: {num_channels_mask} + '
                '`num_channels_masked_image`: '
                f'{num_channels_masked_image} = {total_channels}. '
                'Please verify the config of `pipeline.unet` '
                'or your `mask_image` or `image` input.')

        # 9. Prepare extra step kwargs.
        # TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 10. Denoising loop
        # The text embeddings do not change across steps; cast them once.
        text_embeddings = text_embeddings.to(latent_dtype)
        if show_progress:
            timesteps = tqdm(timesteps)
        for t in timesteps:
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.test_scheduler.scale_model_input(
                latent_model_input, t)

            # The 9-channel inpainting UNet receives the mask and the masked
            # image latents as extra input channels.
            latent_model_input = torch.cat(
                [latent_model_input, mask, masked_image_latents],
                dim=1).to(latent_dtype)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t,
                encoder_hidden_states=text_embeddings)['sample']

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.test_scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs)['prev_sample']

        # 11. Post-processing
        image = self.decode_latents(latents.to(img_dtype))
        if return_type == 'image':
            image = self.output_to_pil(image)
        elif return_type == 'numpy':
            image = image.cpu().numpy()
        # For 'tensor' the decoder output is returned unchanged.

        return {'samples': image}

    def prepare_mask_latents(self, mask, masked_image, batch_size,
                             num_channels_latents, height, width, dtype,
                             device, generator, do_classifier_free_guidance):
        """Prepare the mask and masked-image latents fed to the UNet.

        Args:
            mask (torch.Tensor): The mask marking the regions to inpaint, in
                pixel space. It is resized to the latent resolution.
            masked_image (torch.Tensor): The image with the inpainting
                regions zeroed out, in pixel space. It is encoded with the
                VAE.
            batch_size (int): Total batch size (prompts x images per
                prompt). Mask and latents are repeated up to this size.
            num_channels_latents (int): Latent channel number (only used to
                describe the target shape).
            height (int): Image height in pixels.
            width (int): Image width in pixels.
            dtype (torch.dtype): Floating point type of the outputs.
            device (torch.device): Device of the outputs.
            generator (torch.Generator, optional): Generator used when
                sampling from the VAE latent distribution.
            do_classifier_free_guidance (bool): Whether to apply
                classifier-free guidance. If True, both outputs are
                duplicated along the batch dimension.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The resized mask and the
            scaled masked-image latents, both on ``device`` with ``dtype``.

        Raises:
            ValueError: If ``batch_size`` is not divisible by the number of
                masks or masked images provided.
        """
        shape = (batch_size, num_channels_latents,
                 height // self.vae_scale_factor,
                 width // self.vae_scale_factor)
        # F.interpolate defaults to nearest-neighbour, so a binary mask
        # stays binary at the latent resolution.
        mask = F.interpolate(
            mask, size=shape[2:]).to(
                device=device, dtype=dtype)
        masked_image = masked_image.to(device=device, dtype=dtype)

        # Unwrap a possible `.module` wrapper, as `infer` also does.
        vae = getattr(self.vae, 'module', self.vae)
        masked_image_latents = vae.encode(masked_image).latent_dist.sample(
            generator)
        masked_image_latents = vae.config.scaling_factor * \
            masked_image_latents

        # duplicate mask and masked_image_latents for each generation per
        # prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if batch_size % mask.shape[0] != 0:
                raise ValueError(
                    "The passed mask and the required batch size don't "
                    'match. Masks are supposed to be duplicated to a total '
                    f'batch size of {batch_size}, but {mask.shape[0]} masks '
                    'were passed. Make sure the total requested batch size '
                    'is divisible by the number of masks that you pass.')
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
        if masked_image_latents.shape[0] < batch_size:
            if batch_size % masked_image_latents.shape[0] != 0:
                raise ValueError(
                    "The passed images and the required batch size don't "
                    'match. Images are supposed to be duplicated to a total'
                    f' batch size of {batch_size}, but '
                    f'{masked_image_latents.shape[0]} images were passed.'
                    ' Make sure the total requested batch size is divisible'
                    ' by the number of images that you pass.')
            masked_image_latents = masked_image_latents.repeat(
                batch_size // masked_image_latents.shape[0], 1, 1, 1)

        # With guidance, the UNet input holds unconditional and conditional
        # halves, so both extra inputs are duplicated as well.
        if do_classifier_free_guidance:
            mask = torch.cat([mask] * 2)
            masked_image_latents = torch.cat([masked_image_latents] * 2)

        # aligning device to prevent device errors when concatenating it with
        # the latent model input
        masked_image_latents = masked_image_latents.to(
            device=device, dtype=dtype)
        return mask, masked_image_latents

    @torch.no_grad()
    def val_step(self, data: dict) -> SampleList:
        """Performs a validation step on the provided data.

        This method is decorated with `torch.no_grad()` which indicates no
        gradients will be computed during the operations. This ensures
        efficient memory usage during validation.

        Args:
            data (dict): Dictionary containing input data for validation.

        Returns:
            SampleList: List of samples processed during the validation step.

        Raises:
            NotImplementedError: This method has not been implemented.
        """
        raise NotImplementedError

    @torch.no_grad()
    def test_step(self, data: dict) -> SampleList:
        """Performs a testing step on the provided data.

        This method is decorated with `torch.no_grad()` which indicates no
        gradients will be computed during the operations. This ensures
        efficient memory usage during testing.

        Args:
            data (dict): Dictionary containing input data for testing.

        Returns:
            SampleList: List of samples processed during the testing step.

        Raises:
            NotImplementedError: This method has not been implemented.
        """
        raise NotImplementedError

    def train_step(self, data, optim_wrapper_dict):
        """Performs a training step on the provided data.

        Args:
            data: Input data for training.
            optim_wrapper_dict: Dictionary containing optimizer wrappers
                which may contain optimizers, schedulers, etc. required
                for the training step.

        Raises:
            NotImplementedError: This method has not been implemented.
        """
        raise NotImplementedError


def prepare_mask_and_masked_image(image: torch.Tensor,
                                  mask: torch.Tensor,
                                  height: int = 512,
                                  width: int = 512,
                                  return_image: bool = False):
    """Prepare a binary mask and the masked image for inpainting.

    ``image`` and ``mask`` must either both be tensors or both be
    non-tensors. Accepted formats:

    - ``torch.Tensor``: ``image`` of shape (3, H, W) or (B, 3, H, W) with
      values in [-1, 1]; ``mask`` of shape (H, W), (1, H, W), (B, H, W) or
      (B, 1, H, W) with values in [0, 1]. Tensors are not resized, and the
      caller's tensors are not modified.
    - ``PIL.Image.Image`` or a list of them: resized to ``(width, height)``
      with Lanczos resampling. Images are converted to RGB and rescaled from
      [0, 255] to [-1, 1]. Masks are converted to grayscale and rescaled to
      [0, 1].
    - ``np.ndarray`` or a list of them: each image of shape (H, W, C) with
      values in [0, 255], each mask of shape (H, W) with values in [0, 1].
      Arrays are not resized.

    Args:
        image (torch.Tensor | PIL.Image.Image | np.ndarray | list): The
            image(s) to be masked.
        mask (torch.Tensor | PIL.Image.Image | np.ndarray | list): The
            mask(s) marking the regions to inpaint.
        height (int): Target height for PIL inputs. Defaults to 512.
        width (int): Target width for PIL inputs. Defaults to 512.
        return_image (bool): Whether to also return the preprocessed image.
            Default to `False`.

    Returns:
        tuple: ``(mask, masked_image)``, or ``(mask, masked_image, image)``
        when ``return_image`` is True, where

        - mask (torch.Tensor): A binary (0/1) mask of shape (B, 1, H, W);
          1 marks the regions to inpaint.
        - masked_image (torch.Tensor): The preprocessed image with the
          masked regions set to 0.
        - image (torch.Tensor): The preprocessed float32 image in [-1, 1].

    Raises:
        ValueError: If ``image`` or ``mask`` is None or an empty list, or if
            a tensor input is outside its expected value range.
        TypeError: If only one of ``image`` and ``mask`` is a tensor.
        AssertionError: If tensor inputs have incompatible shapes.
    """

    if image is None:
        raise ValueError('`image` input cannot be undefined.')

    if mask is None:
        raise ValueError('`mask_image` input cannot be undefined.')

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError('`image` is a torch.Tensor but `mask` (type: '
                            f'{type(mask)}) is not')

        # Batch single image
        if image.ndim == 3:
            assert image.shape[
                0] == 3, 'Image outside a batch should be of shape (3, H, W)'
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # (1, H, W): single batched mask without channel dim, or single
            # mask with channel dim -> (1, 1, H, W).
            # (B, H, W): batched masks without channel dim -> (B, 1, H, W).
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)
            else:
                mask = mask.unsqueeze(1)

        assert (image.ndim == 4
                and mask.ndim == 4), 'Image and Mask must have 4 dimensions'
        assert image.shape[-2:] == mask.shape[
            -2:], 'Image and Mask must have the same spatial dimensions'
        assert image.shape[0] == mask.shape[
            0], 'Image and Mask must have the same batch size'

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError('Image should be in [-1, 1] range')

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError('Mask should be in [0, 1] range')

        # Binarize out of place: `unsqueeze` returns views, so in-place
        # writes here would silently modify the caller's mask tensor.
        mask = (mask >= 0.5).to(mask.dtype)

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(
            f'`mask` is a torch.Tensor but `image` (type: {type(image)}) is '
            'not')
    else:
        # Indexing `[0]` below would otherwise fail with an opaque
        # IndexError on empty lists.
        if isinstance(image, list) and not image:
            raise ValueError('`image` input cannot be an empty list.')
        if isinstance(mask, list) and not mask:
            raise ValueError('`mask_image` input cannot be an empty list.')

        # preprocess image
        if isinstance(image, (Image.Image, np.ndarray)):
            image = [image]
        if isinstance(image, list) and isinstance(image[0], Image.Image):
            # resize all images w.r.t passed height an width
            image = [
                i.resize((width, height), resample=Image.LANCZOS)
                for i in image
            ]
            image = [np.array(i.convert('RGB'))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        # (B, H, W, C) uint8-range -> (B, C, H, W) float32 in [-1, 1]
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], Image.Image):
            mask = [
                i.resize((width, height), resample=Image.LANCZOS) for i in mask
            ]
            mask = np.concatenate(
                [np.array(m.convert('L'))[None, None, :] for m in mask],
                axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        # `mask` is a freshly concatenated array here, so binarizing in
        # place cannot affect the caller's data.
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    masked_image = image * (mask < 0.5)

    # n.b. ensure backwards compatibility as old function does not return image
    if return_image:
        return mask, masked_image, image

    return mask, masked_image
